{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Classification Trees"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Gini Impurity:\n",
    "\n",
    "$$L(\\mathcal{R}_m) = \\sum_{k = 1}^K \\hat{p}_{mk}(1-\\hat{p}_{mk})$$\n",
    "\n",
    "Cross entropy:\n",
    "\n",
    "$$L(\\mathcal{R}_m) = -\\sum_{k = 1}^K \\hat{p}_{mk}\\log_2(\\hat{p}_{mk}) $$\n",
    "\n",
    "Each split is chosen to minimize the size-weighted average loss of the two child regions:\n",
    "\n",
    "$$\n",
    "\\frac{|\\mathcal{R}_{c1}|L(\\mathcal{R}_{c1}) + |\\mathcal{R}_{c2}|L(\\mathcal{R}_{c2})}{|\\mathcal{R}_{c1}| + |\\mathcal{R}_{c2}|}\n",
    "$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note: one shortcut is taken here: `combinations` from `itertools` is used to enumerate the candidate groupings of a categorical predictor's values."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "from itertools import combinations\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the penguins data and keep only the rows with no missing values\n",
    "penguins = sns.load_dataset('penguins')\n",
    "penguins = penguins.dropna()\n",
    "\n",
    "# Every column except the label is a predictor\n",
    "feature_frame = penguins.drop(columns = 'species')\n",
    "X = np.array(feature_frame)\n",
    "y = np.array(penguins['species'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gini_impurity(y):\n",
    "    \"\"\"Return the Gini impurity of the labels in `y`.\n",
    "\n",
    "    This is the sum over the distinct labels of p*(1 - p), where p is the\n",
    "    fraction of the entries of `y` carrying that label.\n",
    "    \"\"\"\n",
    "    _, class_counts = np.unique(y, return_counts = True)\n",
    "    proportions = class_counts/len(y)\n",
    "    return np.sum(proportions*(1 - proportions))\n",
    "     \n",
    "def cross_entropy(y):\n",
    "    \"\"\"Return the cross entropy (in bits) of the labels in `y`.\n",
    "\n",
    "    This is minus the sum over the distinct labels of p*log2(p), where p is\n",
    "    the fraction of the entries of `y` carrying that label.\n",
    "    \"\"\"\n",
    "    _, class_counts = np.unique(y, return_counts = True)\n",
    "    proportions = class_counts/len(y)\n",
    "    return -np.sum(proportions*np.log2(proportions))\n",
    "\n",
    "def split_loss(child1, child2, loss = cross_entropy):\n",
    "    \"\"\"Return the size-weighted average of `loss` over the two children.\n",
    "\n",
    "    `child1` and `child2` are the label arrays of the two sides of a split;\n",
    "    each side's loss is weighted by its number of labels.\n",
    "    \"\"\"\n",
    "    n1, n2 = len(child1), len(child2)\n",
    "    weighted_total = n1*loss(child1) + n2*loss(child2)\n",
    "    return weighted_total/(n1 + n2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "tags": [
     "hide-input"
    ]
   },
   "outputs": [],
   "source": [
    "def all_possible_splits(x):\n",
    "    \"\"\"Return every non-empty proper subset of `x` as a numpy array.\n",
    "\n",
    "    Each integer from 1 to 2**len(x) - 2 is written in binary, zero-padded\n",
    "    to len(x) digits, and used as a boolean mask over `x`.\n",
    "    Note: a later cell defines another function with this same name.\n",
    "    \"\"\"\n",
    "    n = len(x)\n",
    "    values = np.array(x)\n",
    "    subsets = []\n",
    "    for i in range(1, 2**n - 1):\n",
    "        mask = [digit == '1' for digit in format(i, 'b').zfill(n)]\n",
    "        subsets.append(values[mask])\n",
    "    return subsets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[1], [3], [4], [1, 3], [1, 4], [3, 4]]"
      ]
     },
     "execution_count": 80,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def all_rows_equal(X):\n",
    "    \"\"\"Return whether every row of `X` matches its first row.\"\"\"\n",
    "    first_row = X[0]\n",
    "    matches_first = (X == first_row)\n",
    "    return matches_first.all()\n",
    "\n",
    "def all_possible_splits(x):\n",
    "    \"\"\"Return every non-empty proper subset of `x`, each as a list.\n",
    "\n",
    "    Subsets come out smallest first and, within one size, in the order\n",
    "    produced by `itertools.combinations`.\n",
    "    \"\"\"\n",
    "    return [list(combo)\n",
    "            for size in range(1, len(x))\n",
    "            for combo in combinations(x, size)]\n",
    "\n",
    "all_possible_splits([1,3,4])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Node:\n",
    "    \"\"\"One node of the tree, holding the observations assigned to it.\"\"\"\n",
    "\n",
    "    def __init__(self, Xsub, ysub, ID, depth = 0, parent_ID = None, leaf = True):\n",
    "        # Identity and position within the tree\n",
    "        self.ID, self.parent_ID, self.depth = ID, parent_ID, depth\n",
    "        # Observations assigned to this node, and how many there are\n",
    "        self.Xsub, self.ysub = Xsub, ysub\n",
    "        self.size = len(ysub)\n",
    "        # Whether the node is currently a leaf (the default on creation)\n",
    "        self.leaf = leaf\n",
    "        \n",
    "\n",
    "class Splitter:\n",
    "    \"\"\"Record of the best split found so far during a search.\n",
    "\n",
    "    Attributes\n",
    "    ----------\n",
    "    loss : loss of the recorded split (infinity until one is recorded).\n",
    "    parent_ID : ID of the node the split applies to.\n",
    "    d : index of the predictor being split on.\n",
    "    dtype : 'quant' or 'cat'.\n",
    "    t : threshold of a quantitative split, otherwise None.\n",
    "    L_values : left-hand values of a categorical split, otherwise None.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        # Any real loss compares below infinity, so the first candidate wins\n",
    "        self.loss = np.inf\n",
    "        # Define every attribute up front, so a search that records nothing\n",
    "        # still leaves a complete (empty) object instead of missing attributes\n",
    "        self.parent_ID = None\n",
    "        self.d = None\n",
    "        self.dtype = None\n",
    "        self.t = None\n",
    "        self.L_values = None\n",
    "\n",
    "    def replace_split(self, loss, parent_ID, d, dtype = 'quant', t = None, L_values = None):\n",
    "        \"\"\"Overwrite the recorded split with a new one.\"\"\"\n",
    "        self.loss = loss\n",
    "        self.parent_ID = parent_ID\n",
    "        self.d = d\n",
    "        self.dtype = dtype\n",
    "        self.t = t\n",
    "        self.L_values = L_values"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 82,
   "metadata": {},
   "outputs": [],
   "source": [
    "class DecisionTreeClassifier:\n",
    "    \"\"\"Classification tree grown by repeated greedy binary splits.\n",
    "\n",
    "    Quantitative predictors are split on a threshold (rows with x <= t go\n",
    "    left); categorical predictors are split on a subset of their values\n",
    "    (rows whose value is in L_values go left). Each leaf predicts the most\n",
    "    frequent training label among the observations it holds.\n",
    "    \"\"\"\n",
    "\n",
    "    #############################\n",
    "    ######## 1. TRAINING ########\n",
    "    #############################\n",
    "\n",
    "    ######### FIT ##########\n",
    "    def fit(self, X, y, loss_func = cross_entropy, max_depth = 100, min_size = 2):\n",
    "        \"\"\"Grow the tree on the training data.\n",
    "\n",
    "        Parameters\n",
    "        ----------\n",
    "        X : array-like of shape (N, D); columns may mix numbers and strings.\n",
    "        y : array-like of length N holding the class labels.\n",
    "        loss_func : function mapping an array of labels to a loss value.\n",
    "        max_depth : nodes at this depth are not split any further.\n",
    "        min_size : nodes with fewer observations than this are not split.\n",
    "        \"\"\"\n",
    "\n",
    "        ## Add data (as arrays, so the boolean masks used below also work\n",
    "        ## when lists or data frames are passed in)\n",
    "        self.X = np.asarray(X)\n",
    "        self.y = np.asarray(y)\n",
    "        self.N, self.D = self.X.shape\n",
    "\n",
    "        ## Infer each column's type on its own: a mixed-type X is held as an\n",
    "        ## object array, which hides which columns are numeric\n",
    "        dtypes = [np.array(list(self.X[:,d])).dtype for d in range(self.D)]\n",
    "        self.dtypes = ['quant' if (np.issubdtype(dtype, np.integer) or\n",
    "                                   np.issubdtype(dtype, np.floating))\n",
    "                       else 'cat' for dtype in dtypes]\n",
    "\n",
    "        ## Add model parameters\n",
    "        self.loss_func = loss_func\n",
    "        self.max_depth = max_depth\n",
    "        self.min_size = min_size\n",
    "\n",
    "        ## Initialize nodes\n",
    "        self.nodes_dict = {}\n",
    "        self.current_ID = 0\n",
    "        initial_node = Node(Xsub = self.X, ysub = self.y, ID = self.current_ID, parent_ID = None)\n",
    "        self.nodes_dict[self.current_ID] = initial_node\n",
    "        self.current_ID += 1\n",
    "\n",
    "        # Build\n",
    "        self.build()\n",
    "\n",
    "        # Calculate leaf modes\n",
    "        self.get_leaf_modes()\n",
    "\n",
    "\n",
    "    ###### FIND SPLIT ######\n",
    "    def find_split(self, eligible_parents):\n",
    "        \"\"\"Search every eligible node for the single lowest-loss split.\n",
    "\n",
    "        `eligible_parents` maps node IDs to nodes. The best candidate is\n",
    "        stored in `self.splitter`; ties keep the first candidate found.\n",
    "        \"\"\"\n",
    "\n",
    "        ## Instantiate splitter\n",
    "        splitter = Splitter()\n",
    "\n",
    "        ## For each eligible parent node...\n",
    "        for parent_ID, parent in eligible_parents.items():\n",
    "            ysub = parent.ysub\n",
    "\n",
    "            ## For each predictor...\n",
    "            for d in range(self.D):\n",
    "                Xsub_d = parent.Xsub[:,d]\n",
    "                dtype = self.dtypes[d]\n",
    "\n",
    "                ## A predictor with a single value cannot separate anything\n",
    "                values = np.unique(Xsub_d)\n",
    "                if len(values) == 1:\n",
    "                    continue\n",
    "\n",
    "                ## For each value...\n",
    "                if dtype == 'quant':\n",
    "                    ## The largest value is skipped: it would send every row left\n",
    "                    for t in values[:-1]:\n",
    "                        ysub_L = ysub[Xsub_d <= t]\n",
    "                        ysub_R = ysub[Xsub_d > t]\n",
    "                        loss = split_loss(ysub_L, ysub_R, loss = self.loss_func)\n",
    "                        if loss < splitter.loss:\n",
    "                            splitter.replace_split(loss, parent_ID, d, 'quant', t = t)\n",
    "                else:\n",
    "                    for L_values in all_possible_splits(values):\n",
    "                        goes_left = np.isin(Xsub_d, L_values)\n",
    "                        ysub_L = ysub[goes_left]\n",
    "                        ysub_R = ysub[~goes_left]\n",
    "                        loss = split_loss(ysub_L, ysub_R, loss = self.loss_func)\n",
    "                        if loss < splitter.loss:\n",
    "                            splitter.replace_split(loss, parent_ID, d, 'cat', L_values = L_values)\n",
    "\n",
    "        ## Save splitter\n",
    "        self.splitter = splitter\n",
    "\n",
    "    ###### MAKE SPLIT ######\n",
    "    def make_split(self):\n",
    "        \"\"\"Apply the split held in `self.splitter`, creating two child nodes.\"\"\"\n",
    "\n",
    "        ## Update parent nodes\n",
    "        parent_node = self.nodes_dict[self.splitter.parent_ID]\n",
    "        parent_node.leaf = False\n",
    "        parent_node.child_L = self.current_ID\n",
    "        parent_node.child_R = self.current_ID + 1\n",
    "        parent_node.d = self.splitter.d\n",
    "        parent_node.dtype = self.splitter.dtype\n",
    "        parent_node.t = self.splitter.t\n",
    "        parent_node.L_values = self.splitter.L_values\n",
    "\n",
    "        ## Get X and y data for children\n",
    "        if parent_node.dtype == 'quant':\n",
    "            L_condition = parent_node.Xsub[:,parent_node.d] <= parent_node.t\n",
    "        else:\n",
    "            L_condition = np.isin(parent_node.Xsub[:,parent_node.d], parent_node.L_values)\n",
    "        Xchild_L = parent_node.Xsub[L_condition]\n",
    "        ychild_L = parent_node.ysub[L_condition]\n",
    "        Xchild_R = parent_node.Xsub[~L_condition]\n",
    "        ychild_R = parent_node.ysub[~L_condition]\n",
    "\n",
    "        ## Create child nodes\n",
    "        child_node_L = Node(Xchild_L, ychild_L, depth = parent_node.depth + 1,\n",
    "                            ID = self.current_ID, parent_ID = parent_node.ID)\n",
    "        child_node_R = Node(Xchild_R, ychild_R, depth = parent_node.depth + 1,\n",
    "                            ID = self.current_ID+1, parent_ID = parent_node.ID)\n",
    "        self.nodes_dict[self.current_ID] = child_node_L\n",
    "        self.nodes_dict[self.current_ID + 1] = child_node_R\n",
    "        self.current_ID += 2\n",
    "\n",
    "    ###### ELIGIBILITY #####\n",
    "    def _is_eligible(self, node):\n",
    "        \"\"\"Return whether `node` may still be split.\"\"\"\n",
    "        ## `and`/`not` short-circuit and are safe on plain Python bools,\n",
    "        ## unlike the bitwise `&` and `~`\n",
    "        return bool(node.leaf\n",
    "                    and node.depth < self.max_depth\n",
    "                    and node.size >= self.min_size\n",
    "                    and not all_rows_equal(node.Xsub))\n",
    "\n",
    "    ###### BUILD TREE ######\n",
    "    def build(self):\n",
    "        \"\"\"Split one node at a time until no node is eligible for splitting.\"\"\"\n",
    "\n",
    "        while True:\n",
    "\n",
    "            ## Find eligible nodes; this is checked before the first split as\n",
    "            ## well, so the root also respects the stopping rules\n",
    "            eligible_parents = {ID:node for (ID, node) in self.nodes_dict.items()\n",
    "                                if self._is_eligible(node)}\n",
    "\n",
    "            ## Quit if no more eligible parents\n",
    "            if len(eligible_parents) == 0:\n",
    "                break\n",
    "\n",
    "            ## Find split among eligible parent nodes\n",
    "            self.find_split(eligible_parents)\n",
    "\n",
    "            ## Quit if the search recorded no split at all\n",
    "            if getattr(self.splitter, 'parent_ID', None) is None:\n",
    "                break\n",
    "\n",
    "            ## Make split\n",
    "            self.make_split()\n",
    "\n",
    "\n",
    "    ###### LEAF MODES ######\n",
    "    def get_leaf_modes(self):\n",
    "        \"\"\"Store, for every leaf, the most frequent label among its observations.\"\"\"\n",
    "        self.leaf_modes = {}\n",
    "        for node_ID, node in self.nodes_dict.items():\n",
    "            if node.leaf:\n",
    "                values, counts = np.unique(node.ysub, return_counts=True)\n",
    "                self.leaf_modes[node_ID] = values[np.argmax(counts)]\n",
    "\n",
    "    #############################\n",
    "    ####### 2. PREDICTING #######\n",
    "    #############################\n",
    "\n",
    "    ####### PREDICT ########\n",
    "    def predict(self, X_test):\n",
    "        \"\"\"Return an array with the predicted label for each row of `X_test`.\"\"\"\n",
    "        yhat = []\n",
    "        ## Iterate over rows even when a list or data frame is passed in\n",
    "        for x in np.asarray(X_test):\n",
    "            ## Walk down from the root until a leaf is reached\n",
    "            node = self.nodes_dict[0]\n",
    "            while not node.leaf:\n",
    "                if node.dtype == 'quant':\n",
    "                    goes_left = x[node.d] <= node.t\n",
    "                else:\n",
    "                    goes_left = x[node.d] in node.L_values\n",
    "                node = self.nodes_dict[node.child_L if goes_left else node.child_R]\n",
    "            yhat.append(self.leaf_modes[node.ID])\n",
    "        return np.array(yhat)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 83,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hold out a random quarter of the rows for testing\n",
    "test_frac = 0.25\n",
    "test_size = int(len(y)*test_frac)\n",
    "np.random.seed(0)  # fixed seed so the split is the same on every run\n",
    "test_idxs = np.random.choice(np.arange(len(y)), test_size, replace = False)\n",
    "\n",
    "# test_idxs holds integer positions, so `~test_idxs` would be a bitwise NOT\n",
    "# (negative positions) rather than 'every other row'. Build a boolean mask\n",
    "# and negate that, so the training rows are exactly the rows not held out.\n",
    "is_test = np.zeros(len(y), dtype = bool)\n",
    "is_test[test_idxs] = True\n",
    "X_train = X[~is_test]\n",
    "y_train = y[~is_test]\n",
    "X_test = X[test_idxs]\n",
    "y_test = y[test_idxs]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 84,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Grow the tree on the training rows, then classify the held-out rows\n",
    "tree = DecisionTreeClassifier()\n",
    "tree.fit(X = X_train, y = y_train, max_depth = 10, min_size = 10)\n",
    "y_test_hat = tree.predict(X_test = X_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.963855421686747"
      ]
     },
     "execution_count": 85,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Test accuracy: the share of held-out rows that were classified correctly\n",
    "(y_test_hat == y_test).mean()"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Edit Metadata",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
